# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn
from loguru import logger


def _log_model_attribute(attr_name, attr_value):
    """Log ``attr_value`` if it looks like a model component, plus its config keys when present."""
    type_text = str(type(attr_value))
    # Heuristic: anything exposing ``parameters`` or whose type comes from torch.nn / diffusers.
    looks_like_model = hasattr(attr_value, "parameters") or "torch.nn" in type_text or "diffusers" in type_text
    if not looks_like_model:
        return

    logger.info(f"  {attr_name}: {type(attr_value)} - {attr_value.__class__.__module__}")
    if not hasattr(attr_value, "config"):
        return

    # Best effort only: a failure to read the config must not abort the test.
    try:
        config_keys = list(attr_value.config.keys()) if hasattr(attr_value.config, "keys") else "N/A"
        logger.info(f"    Config keys: {config_keys}")
    except Exception as e:
        logger.info(f"    Config keys: Error accessing config - {e}")


def _log_pipeline_components(pipe):
    """Log the model-like public attributes of ``pipe`` and, if present, its ``components`` mapping."""
    logger.info("=== Mochi Pipeline Constituent Models ===")
    for attr_name in dir(pipe):
        if attr_name.startswith("_"):
            continue
        try:
            _log_model_attribute(attr_name, getattr(pipe, attr_name))
        except AttributeError as e:
            # Some names listed by dir() raise AttributeError when actually accessed.
            logger.debug(f"  Skipped {attr_name}: {e}")
        except Exception as e:
            logger.debug(f"  Error accessing {attr_name}: {e}")

    if not hasattr(pipe, "components"):
        return

    logger.info("\n=== Pipeline Components Dictionary ===")
    for component_name, component in pipe.components.items():
        logger.info(f"  {component_name}: {type(component)}")
        if hasattr(component, "config"):
            try:
                logger.info(f"    Config: {type(component.config)}")
            except Exception as e:
                logger.info(f"    Config: Error accessing config - {e}")


def test_mochi_diffusers_pipeline():
    """
    Load the Diffusers Mochi pipeline, log its components and run it on a prompt.

    The test is skipped when ``diffusers`` (or its ``MochiPipeline``) cannot be imported.
    It asserts that at least one frame is produced and writes the result to
    ``mochi_test_output.mp4`` in the current working directory for manual inspection.
    """
    try:
        from diffusers import MochiPipeline
        from diffusers.utils import export_to_video
    except ImportError:
        pytest.skip("diffusers library not available or MochiPipeline not found")

    # Use bfloat16 on GPU and float32 on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

    logger.info(f"Loading Mochi pipeline on device: {device} with dtype: {torch_dtype}")

    pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch_dtype)

    if device == "cuda":
        # Model CPU offload manages device placement itself. Moving the whole pipeline
        # to the GPU first would defeat the memory saving and can run out of memory,
        # so the pipeline is deliberately not moved with ``pipe.to("cuda")`` here.
        pipe.enable_model_cpu_offload()
        pipe.enable_vae_tiling()
    else:
        pipe.to(device)

    _log_pipeline_components(pipe)

    prompt = "A close-up of a beautiful butterfly landing on a flower, wings gently moving in the breeze."

    logger.info(f"Generating video with prompt: '{prompt}'")

    # Small step count, frame count and resolution keep the run short.
    frames = pipe(
        prompt,
        num_inference_steps=10,
        guidance_scale=3.5,
        num_frames=16,
        height=320,
        width=320,
    ).frames[0]

    # Only non-emptiness is checked; the exact frame count is intentionally not asserted.
    # NOTE(review): a strict ``len(frames) == 16`` check was previously disabled here -
    # presumably the pipeline does not always return exactly ``num_frames``; confirm.
    assert frames is not None, "No frames were generated by the pipeline"
    assert len(frames) > 0, "Empty frames list generated by the pipeline"

    first_frame = frames[0]
    logger.info(f"Generated {len(frames)} frames, first frame size: {first_frame.size}")

    # Always export the video so the output can be inspected manually after the run.
    export_to_video(frames, "mochi_test_output.mp4", fps=8)
    logger.info("Video exported to mochi_test_output.mp4")

    logger.info("Mochi pipeline test completed successfully!")


@pytest.mark.parametrize(
    "mesh_device, sp_axis, tp_axis, num_links",
    [[(4, 8), 1, 0, 4]],
    ids=["4x8sp1tp0"],
    indirect=["mesh_device"],
)
@pytest.mark.parametrize("device_params", [{"fabric_config": ttnn.FabricConfig.FABRIC_1D}], indirect=True)
def test_tt_mochi_pipeline(
    mesh_device: ttnn.MeshDevice,
    sp_axis: int,
    tp_axis: int,
    num_links: int,
):
    """
    Build the TT MochiPipeline on the mesh device and generate a video from a prompt.

    The sequence- and tensor-parallel factors are read from the mesh shape along
    ``sp_axis`` and ``tp_axis``. The test is skipped when the project-local pipeline
    or parallel-config modules cannot be imported. It asserts that at least one
    frame is produced and, when ``diffusers`` is importable, exports the frames to
    ``tt_mochi_test_output.mp4``.
    """
    try:
        from ....pipelines.mochi.pipeline_mochi import MochiPipeline as TTMochiPipeline
        from ....parallel.config import DiTParallelConfig, MochiVAEParallelConfig, ParallelFactor
    except ImportError as import_error:
        pytest.skip(f"Required TT modules not available: {import_error}")

    # Parallel factors are the mesh extents along the requested axes.
    mesh_shape = tuple(mesh_device.shape)
    sp_factor = mesh_shape[sp_axis]
    tp_factor = mesh_shape[tp_axis]

    logger.info(f"Creating TT Mochi pipeline on mesh device with shape {mesh_device.shape}")
    logger.info(f"SP factor: {sp_factor}, TP factor: {tp_factor}")

    # Transformer parallelism: CFG parallelism is fixed at factor 1 on mesh axis 0.
    cfg_factor = ParallelFactor(factor=1, mesh_axis=0)
    tensor_factor = ParallelFactor(factor=tp_factor, mesh_axis=tp_axis)
    sequence_factor = ParallelFactor(factor=sp_factor, mesh_axis=sp_axis)
    parallel_config = DiTParallelConfig(
        cfg_parallel=cfg_factor,
        tensor_parallel=tensor_factor,
        sequence_parallel=sequence_factor,
    )

    # VAE parallelism: time uses mesh axis 0; height and width split mesh axis 1,
    # with height taking a fixed factor of 4 and width the remainder.
    h_parallel_factor = 4
    mesh_rows = mesh_device.shape[0]
    mesh_cols = mesh_device.shape[1]
    vae_parallel_config = MochiVAEParallelConfig(
        time_parallel=ParallelFactor(factor=mesh_rows, mesh_axis=0),
        h_parallel=ParallelFactor(factor=h_parallel_factor, mesh_axis=1),
        w_parallel=ParallelFactor(factor=mesh_cols // h_parallel_factor, mesh_axis=1),
    )
    h_split = vae_parallel_config.h_parallel
    w_split = vae_parallel_config.w_parallel
    # The height/width split must tile mesh axis 1 exactly and share that axis.
    assert h_split.factor * w_split.factor == mesh_device.shape[1]
    assert h_split.mesh_axis == w_split.mesh_axis

    tt_pipe = TTMochiPipeline(
        mesh_device=mesh_device,
        parallel_config=parallel_config,
        vae_parallel_config=vae_parallel_config,
        num_links=num_links,
        use_cache=True,
        use_reference_vae=False,
        model_name="genmo/mochi-1-preview",
    )

    prompt = "A close-up of a beautiful butterfly landing on a flower, wings gently moving in the breeze."

    logger.info(f"Generating video with TT pipeline using prompt: '{prompt}'")

    # Generation settings for this run: 50 steps, 168 frames at 480x848.
    generation_settings = {
        "num_inference_steps": 50,
        "guidance_scale": 3.5,
        "num_frames": 168,
        "height": 480,
        "width": 848,
    }
    pipeline_output = tt_pipe(prompt, **generation_settings)
    frames = pipeline_output.frames[0]

    # Only the presence of frames is checked.
    assert frames is not None, "No frames were generated by the TT pipeline"
    assert len(frames) > 0, "Empty frames list generated by the TT pipeline"

    first_frame = frames[0]
    logger.info(f"TT Pipeline generated {len(frames)} frames, first frame size: {first_frame.size}")

    # Export is best effort: an ImportError from the import or the export itself
    # is logged and does not fail the test.
    try:
        from diffusers.utils import export_to_video

        export_to_video(frames, "tt_mochi_test_output.mp4", fps=30)
        logger.info("TT Pipeline video exported to tt_mochi_test_output.mp4")
    except ImportError:
        logger.info("Could not export video - diffusers.utils.export_to_video not available")

    logger.info("TT Mochi pipeline test completed successfully!")
